package drds.plus.sql_process.optimizer.pre_processor;

import drds.plus.common.jdbc.Parameters;
import drds.plus.sql_process.abstract_syntax_tree.ObjectCreateFactory;
import drds.plus.sql_process.abstract_syntax_tree.expression.NullValue;
import drds.plus.sql_process.abstract_syntax_tree.expression.bind_value.BindValue;
import drds.plus.sql_process.abstract_syntax_tree.expression.bind_value.SequenceValue;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.Item;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.column.Column;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.*;
import drds.plus.sql_process.abstract_syntax_tree.expression.order_by.OrderBy;
import drds.plus.sql_process.abstract_syntax_tree.node.Node;
import drds.plus.sql_process.abstract_syntax_tree.node.dml.Dml;
import drds.plus.sql_process.abstract_syntax_tree.node.dml.Insert;
import drds.plus.sql_process.abstract_syntax_tree.node.dml.Update;
import drds.plus.sql_process.abstract_syntax_tree.node.query.Query;
import drds.plus.sql_process.optimizer.OptimizerException;

import java.util.List;
import java.util.Map;

/**
 * Sequence pre-processor.
 *
 * <p>Walks every node of a statement and assigns the parameter index ({@code originIndex}) of each
 * {@code sequence.nextval} placeholder ({@link SequenceValue}). For DML statements that process
 * auto-increment columns, a {@code null} value (literal or bound) in an auto-increment column is
 * first replaced by a {@link SequenceValue}.
 */
public class SequencePreProcessor {

    /**
     * Entry point: assigns sequence indexes throughout {@code node}.
     *
     * @param node       the statement to process; a {@link Dml} is handled as a DML statement,
     *                   anything else is cast to and handled as a {@link Query}
     * @param parameters bound parameters of the statement, used to resolve bind values and to
     *                   allocate sequence indexes
     * @param extraCmd   not used by this pre-processor
     * @return the same {@code node}, modified in place
     */
    public static Node opitmize(Node node, Parameters parameters, Map<String, Object> extraCmd) throws OptimizerException {
        if (node instanceof Dml) {
            findDml((Dml) node, parameters);
        } else {
            findQuery((Query) node, parameters);
        }
        return node;
    }

    /**
     * Processes the value rows of a DML statement, then the ON DUPLICATE KEY UPDATE values of an
     * {@link Insert} or the SET values of an {@link Update}.
     */
    private static void findDml(Dml dml, Parameters parameters) {
        List<Item> columnNameList = dml.getColumnNameList();
        boolean processAutoIncrement = dml.processAutoIncrement();
        if (dml.isValueListList()) {
            if (dml.getValueListList() != null) {
                for (Object row : dml.getValueListList()) {
                    // each element of a multi-row value list is itself a list of values
                    @SuppressWarnings("unchecked")
                    List<Object> valueList = (List<Object>) row;
                    findValueList(valueList, columnNameList, processAutoIncrement, parameters);
                }
            }
        } else if (dml.getColumnValueList() != null) {
            findValueList(dml.getColumnValueList(), columnNameList, processAutoIncrement, parameters);
        }

        if (dml instanceof Insert) {
            Insert insertNode = (Insert) dml;
            if (insertNode.getDuplicateUpdateValueList() != null) {
                for (Object object : insertNode.getDuplicateUpdateValueList()) {
                    findObject(object, parameters);
                }
            }
        } else if (dml instanceof Update) {
            Update updateNode = (Update) dml;
            if (updateNode.getUpdateColumnValueList() != null) {
                for (Object object : updateNode.getUpdateColumnValueList()) {
                    findObject(object, parameters);
                }
            }
        }
    }

    /**
     * Processes one row of DML values in place: when {@code processAutoIncrement} is set, each value
     * is passed through {@link #convertToSequence} and written back, then every value is scanned
     * for sequence placeholders.
     *
     * @param valueList            the row of values; modified in place
     * @param columnNameList       the target columns; position {@code i} is paired with value
     *                             {@code i}, so it must be at least as long as the row when
     *                             {@code processAutoIncrement} is set
     * @param processAutoIncrement whether null auto-increment values should become sequences
     * @param parameters           bound parameters of the statement
     */
    private static void findValueList(List<Object> valueList, List<Item> columnNameList, boolean processAutoIncrement, Parameters parameters) {
        for (int i = 0; i < valueList.size(); i++) {
            Object value = valueList.get(i);
            if (processAutoIncrement) {
                value = convertToSequence(columnNameList.get(i), value, parameters);
                valueList.set(i, value);
            }
            findObject(value, parameters);
        }
    }

    /**
     * Replaces a null value destined for an auto-increment column with a sequence placeholder.
     *
     * <p>A value is considered null when it is {@code null}, a {@link NullValue}, or a
     * {@link BindValue} whose bound value resolves to {@code null}/{@link NullValue}. In batch mode
     * the bind value is resolved against the first batch ({@code cloneByBatchIndex(0)}).
     *
     * @param item       the target column
     * @param value      the value being inserted into {@code item}
     * @param parameters bound parameters; may be {@code null} (then it is passed as-is to
     *                   {@code assignment})
     * @return a new {@link SequenceValue} for {@code item}'s table when the value is null and the
     *         column is auto-increment, otherwise {@code value} unchanged
     */
    public static Object convertToSequence(Item item, Object value, Parameters parameters) {
        if (!item.isAutoIncrement() || value instanceof SequenceValue) {
            return value;
        }
        if (value == null || value instanceof NullValue) {
            // literal null
            return ObjectCreateFactory.createSequenceValue(item.getTableName());
        }
        if (value instanceof BindValue) {
            // resolve on a copy — presumably so assignment() does not alter the value held in the tree
            Object copy = ((BindValue) value).copy();
            Parameters effectiveParameters = (parameters != null && parameters.isBatch()) ? parameters.cloneByBatchIndex(0) : parameters;
            Object assignedObject = ((BindValue) copy).assignment(effectiveParameters);
            // assignment may still return a BindValue (e.g. in batch mode); unwrap it
            if (assignedObject instanceof BindValue) {
                assignedObject = ((BindValue) assignedObject).getValue();
            }
            if (assignedObject == null || assignedObject instanceof NullValue) {
                return ObjectCreateFactory.createSequenceValue(item.getTableName());
            }
        }
        return value;
    }

    /**
     * Scans every filter, select item, ORDER BY and GROUP BY entry of a query for sequence
     * placeholders.
     */
    private static void findQuery(Query query, Parameters parameters) {
        findFilter(query.getIndexQueryKeyFilter(), parameters);
        findFilter(query.getWhere(), parameters);
        findFilter(query.getResultFilter(), parameters);
        findFilter(query.getOtherJoinOnFilter(), parameters);
        findFilter(query.getHaving(), parameters);

        // guarded like the ORDER BY / GROUP BY lists below
        if (query.getSelectItemList() != null) {
            for (Item item : query.getSelectItemList()) {
                // the item may have replaced a sub-query
                findSelectItem(item, parameters);
            }
        }

        if (query.getOrderByList() != null) {
            for (OrderBy orderBy : query.getOrderByList()) {
                findOrderBy(orderBy, parameters);
            }
        }

        if (query.getGroupByList() != null) {
            for (OrderBy groupBy : query.getGroupByList()) {
                findOrderBy(groupBy, parameters);
            }
        }
    }

    /**
     * Recursively scans a filter tree. Logical and OR filters are descended into; any other
     * filter is treated as a {@link BooleanFilter}. A {@code null} filter is ignored.
     */
    private static void findFilter(Filter filter, Parameters parameters) {
        if (filter == null) {
            return;
        }

        if (filter instanceof LogicalOperationFilter) {
            for (Filter child : ((LogicalOperationFilter) filter).getFilterList()) {
                findFilter(child, parameters);
            }
        } else if (filter instanceof OrsFilter) {
            for (Filter child : ((OrsFilter) filter).getFilterList()) {
                findFilter(child, parameters);
            }
        } else {
            findBooleanFilter((BooleanFilter) filter, parameters);
        }
    }

    /**
     * Scans both operands of a boolean filter and, for {@code IN}, every element of its value
     * list.
     */
    private static void findBooleanFilter(BooleanFilter booleanFilter, Parameters parameters) {
        if (booleanFilter == null) {
            return;
        }

        findObject(booleanFilter.getColumn(), parameters);
        findObject(booleanFilter.getValue(), parameters);
        if (booleanFilter.getOperation() == Operation.in) {
            List<Object> valueList = booleanFilter.getValueList();
            if (valueList != null) {
                for (Object inValue : valueList) {
                    findObject(inValue, parameters);
                }
            }
        }
    }

    /**
     * Dispatches on the runtime type of an expression operand. Items are scanned, sequence values
     * get an index, and scalar sub-queries are descended into depth-first. Anything else is
     * ignored.
     */
    private static void findObject(Object object, Parameters parameters) {
        if (object instanceof Item) {
            findSelectItem((Item) object, parameters);
        } else if (object instanceof SequenceValue) {
            findSequenceValue((SequenceValue) object, parameters);
        } else if (object instanceof Query) {
            // scalar sub-query: recurse depth-first
            findQuery((Query) object, parameters);
        }
    }

    /**
     * Scans a select item. {@link Filter} is checked before {@link Function} so filters take the
     * filter path.
     */
    private static void findSelectItem(Item item, Parameters parameters) {
        if (item instanceof Filter) {
            findFilter((Filter) item, parameters);
        } else if (item instanceof Function) {
            findFunction((Function) item, parameters);
        } else if (item instanceof Column) {
            findColumn((Column) item, parameters);
        }
    }

    /** Columns are leaves; there is nothing to scan. */
    private static void findColumn(Column column, Parameters parameters) {
        // do nothing
    }

    /** Scans the expression of an ORDER BY / GROUP BY entry, if present. */
    private static void findOrderBy(OrderBy orderBy, Parameters parameters) {
        Item item = orderBy.getItem();
        if (item != null) {
            findSelectItem(item, parameters);
        }
    }

    /** Scans every argument of a function call. */
    private static void findFunction(Function function, Parameters parameters) {
        for (Object arg : function.getArgList()) {
            findObject(arg, parameters);
        }
    }

    /**
     * Assigns the next parameter index to a sequence placeholder.
     *
     * <p>The index is the size of {@code getFirstIndexToSetParameterMethodAndArgsMap()} plus the
     * incremented sequence counter. This appears to number sequences after the regular bound
     * parameters. TODO confirm against the executor.
     *
     * <p>NOTE(review): {@code parameters} is dereferenced unconditionally here, although
     * {@link #convertToSequence} treats it as nullable; a statement containing a sequence but no
     * parameters would throw an NPE.
     */
    private static void findSequenceValue(SequenceValue sequenceValue, Parameters parameters) {
        int index = parameters.getFirstIndexToSetParameterMethodAndArgsMap().size() + parameters.getSequenceSize().incrementAndGet();
        sequenceValue.setOriginIndex(index);
    }

}
